{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# PyTorch ONNX Relax 测试"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import operator\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch import fx\n",
    "import torchvision\n",
    "\n",
    "import tvm\n",
    "from tvm import relax\n",
    "import tvm.testing\n",
    "from tvm.script import ir as I\n",
    "from tvm.script import relax as R\n",
    "from tvm.script import tir as T\n",
    "from tvm.relax.frontend import detach_params\n",
    "from tvm.relax.frontend.torch import from_fx\n",
    "\n",
    "\n",
    "def verify_model(torch_model, input_info, binding, expected):\n",
    "    \"\"\"Import ``torch_model`` through FX and compare it with ``expected``.\n",
    "\n",
    "    The model is traced with ``fx.symbolic_trace`` and converted by\n",
    "    ``from_fx``. ``binding`` is bound onto the ``main`` function of\n",
    "    ``expected`` with ``BindParams`` before the structural comparison.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    torch_model : torch.nn.Module\n",
    "        Model to trace and import.\n",
    "    input_info : list\n",
    "        Input description forwarded unchanged to ``from_fx``.\n",
    "    binding : dict\n",
    "        Mapping from parameter name to value; every value is wrapped\n",
    "        with ``tvm.nd.array`` before binding.\n",
    "    expected : tvm.IRModule\n",
    "        Reference module the imported module must match.\n",
    "    \"\"\"\n",
    "    traced = fx.symbolic_trace(torch_model)\n",
    "    with torch.no_grad():\n",
    "        imported = from_fx(traced, input_info)\n",
    "\n",
    "    nd_binding = {}\n",
    "    for name, value in binding.items():\n",
    "        nd_binding[name] = tvm.nd.array(value)\n",
    "\n",
    "    bind_params = relax.transform.BindParams(\"main\", nd_binding)\n",
    "    tvm.ir.assert_structural_equal(imported, bind_params(expected))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class M(torch.nn.Module):\n",
    "    \"\"\"Toy network: conv -> nearest 0.5x resize -> constant scale -> 1x1 conv.\n",
    "\n",
    "    Both convolutions are bias-free, so the only parameters are the two\n",
    "    weight tensors.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.conv = torch.nn.Conv2d(3, 16, 3, bias=False)\n",
    "        self.conv2 = torch.nn.Conv2d(16, 32, 1, bias=False)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Run the four stages in order and return the final feature map.\"\"\"\n",
    "        features = self.conv(x)\n",
    "        # align_corners is deliberately not passed to the nearest-mode resize.\n",
    "        halved = F.interpolate(\n",
    "            features,\n",
    "            size=None,\n",
    "            scale_factor=(0.5, 0.5),\n",
    "            mode=\"nearest\",\n",
    "        )\n",
    "        scaled = halved * 4.019027233123779\n",
    "        return self.conv2(scaled)\n",
    "\n",
    "\n",
    "# Fix the seed so the randomly initialised conv weights are identical on every run.\n",
    "torch.manual_seed(0)\n",
    "torch_model = M()\n",
    "graph_model = fx.symbolic_trace(torch_model)\n",
    "# One (shape, dtype) entry per model input, handed to from_fx.\n",
    "input_info = [([1, 3, 10, 10], \"float32\")]\n",
    "with torch.no_grad():\n",
    "    mod = from_fx(graph_model, input_info)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tvm.relax.dpl import is_const, is_op, rewrite_call, wildcard\n",
    "\n",
    "\n",
    "@tvm.ir.transform.module_pass(opt_level=0, name=\"MulConv2dRewriter\")\n",
    "class MulConv2dRewriterPipeline:\n",
    "    \"\"\"Fold ``conv2d(x * c, w)`` into ``conv2d(x, w * c)`` inside ``main``.\n",
    "\n",
    "    ``c`` and ``w`` are both matched as constants, so the ``w * c`` product\n",
    "    produced by the rewrite is then evaluated by ``FoldConstant``.\n",
    "    \"\"\"\n",
    "\n",
    "    def transform_module(self, mod, ctx):\n",
    "        \"\"\"Rewrite ``mod[\"main\"]`` and constant-fold the result.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        mod : tvm.IRModule\n",
    "            Module whose ``main`` function is rewritten.\n",
    "        ctx : tvm.transform.PassContext\n",
    "            Context supplied by the pass infrastructure; not used here.\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "        tvm.IRModule\n",
    "            The module after the rewrite and ``FoldConstant``.\n",
    "        \"\"\"\n",
    "        x = wildcard()\n",
    "        scale = is_const()\n",
    "        multiply = is_op(\"relax.multiply\")(x, scale)\n",
    "        weight = is_const()\n",
    "        pattern = is_op(\"relax.nn.conv2d\")(multiply, weight)\n",
    "\n",
    "        def rewriter(conv_call, matches):\n",
    "            scale_ = matches[scale]\n",
    "            # Only a scalar (0-d) factor is moved onto the weight: a tensor\n",
    "            # scale would broadcast against the kernel and could change its\n",
    "            # shape or scale the wrong axis. Leave such matches untouched.\n",
    "            if scale_.struct_info.ndim != 0:\n",
    "                return conv_call\n",
    "            folded_weight = matches[weight] * scale_\n",
    "            # Rebuild the call from the matched conv2d so that its strides,\n",
    "            # padding, dilation, groups, layouts and out_dtype are kept.\n",
    "            # R.nn.conv2d(x, w) would reset every attribute to its default.\n",
    "            return relax.Call(\n",
    "                conv_call.op,\n",
    "                [matches[x], folded_weight],\n",
    "                conv_call.attrs,\n",
    "                conv_call.sinfo_args,\n",
    "                conv_call.span,\n",
    "            )\n",
    "\n",
    "        mod[\"main\"] = rewrite_call(pattern, rewriter, mod[\"main\"])\n",
    "        mod = relax.transform.FoldConstant()(mod)\n",
    "        return mod"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import tir as T</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "\n",
       "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
       "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>(inp_0: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">32</span>, <span style=\"color: #008000\">4</span>, <span style=\"color: #008000\">4</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "            lv: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">8</span>, <span style=\"color: #008000\">8</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>nn<span style=\"color: #AA22FF; font-weight: bold\">.</span>conv2d(inp_0, metadata[<span style=\"color: #BA2121\">&quot;relax.expr.Constant&quot;</span>][<span style=\"color: #008000\">0</span>], strides<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], padding<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>], dilation<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], groups<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000\">1</span>, data_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;NCHW&quot;</span>, kernel_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;OIHW&quot;</span>, out_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;NCHW&quot;</span>, out_dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)\n",
       "            lv1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">4</span>, <span style=\"color: #008000\">4</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>image<span style=\"color: #AA22FF; font-weight: bold\">.</span>resize2d(lv, R<span style=\"color: #AA22FF; font-weight: bold\">.</span>shape([<span style=\"color: #008000\">4</span>, <span style=\"color: #008000\">4</span>]), roi<span style=\"color: #AA22FF; font-weight: bold\">=</span>[T<span style=\"color: #AA22FF; font-weight: bold\">.</span>float32(<span style=\"color: #008000\">0.0</span>), T<span style=\"color: #AA22FF; font-weight: bold\">.</span>float32(<span style=\"color: #008000\">0.0</span>), T<span style=\"color: #AA22FF; font-weight: bold\">.</span>float32(<span style=\"color: #008000\">0.0</span>), T<span style=\"color: #AA22FF; font-weight: bold\">.</span>float32(<span style=\"color: #008000\">0.0</span>)], layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;NCHW&quot;</span>, method<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;nearest_neighbor&quot;</span>, coordinate_transformation_mode<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;asymmetric&quot;</span>, rounding_method<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;round&quot;</span>, cubic_alpha<span style=\"color: #AA22FF; font-weight: bold\">=-</span><span style=\"color: #008000\">0.5</span>, cubic_exclude<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000\">0</span>, 
extrapolation_value<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000\">0.0</span>, out_dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;void&quot;</span>)\n",
       "            gv: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">32</span>, <span style=\"color: #008000\">4</span>, <span style=\"color: #008000\">4</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>nn<span style=\"color: #AA22FF; font-weight: bold\">.</span>conv2d(lv1, metadata[<span style=\"color: #BA2121\">&quot;relax.expr.Constant&quot;</span>][<span style=\"color: #008000\">1</span>], strides<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], padding<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>], dilation<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], groups<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000\">1</span>, data_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;NCHW&quot;</span>, kernel_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;OIHW&quot;</span>, out_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;NCHW&quot;</span>, out_dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;void&quot;</span>)\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(gv)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> gv\n",
       "\n",
       "<span style=\"color: #007979; font-style: italic\"># Metadata omitted. Use show_meta=True in script() method to show it.</span>\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from copy import deepcopy\n",
    "\n",
    "# The pass assigns to mod[\"main\"], so take a snapshot of the module first.\n",
    "ori_mod = deepcopy(mod)\n",
    "\n",
    "# Build the pass, apply it, and print the rewritten module.\n",
    "rewrite_pass = MulConv2dRewriterPipeline()\n",
    "mod = rewrite_pass(mod)\n",
    "mod.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decompose composite ops for inference, then legalize relax ops into tensorir.\n",
    "decompose = relax.transform.DecomposeOpsForInference()\n",
    "legalize = relax.transform.LegalizeOps()\n",
    "tvm_model = legalize(decompose(mod))\n",
    "\n",
    "# Split the parameters off from the model.\n",
    "tvm_model, params = detach_params(tvm_model)\n",
    "\n",
    "# Build for the LLVM CPU target and wrap the executable in a relax VM.\n",
    "with tvm.transform.PassContext(opt_level=3):\n",
    "    ex = tvm.compile(tvm_model, target=\"llvm\")\n",
    "    vm = relax.VirtualMachine(ex, tvm.cpu())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "dev = tvm.cpu()\n",
    "# Seeded generator so the random test input is the same on every run.\n",
    "rng = np.random.default_rng(0)\n",
    "inputs_np = [rng.random((1, 3, 10, 10), dtype=np.float32)]\n",
    "inputs = [tvm.nd.array(inp, dev) for inp in inputs_np]\n",
    "# Feed the inputs, run \"main\" and fetch its outputs from the VM.\n",
    "vm.set_input(\"main\", *inputs)\n",
    "vm.invoke_stateful(\"main\")\n",
    "tvm_output = vm.get_outputs(\"main\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference result: run the original PyTorch model on the same input.\n",
    "torch_input = torch.from_numpy(inputs_np[0])\n",
    "with torch.no_grad():\n",
    "    torch_output = torch_model(torch_input)\n",
    "\n",
    "# The rewritten TVM module must agree with PyTorch within tolerance.\n",
    "np.testing.assert_allclose(\n",
    "    tvm_output.numpy(),\n",
    "    torch_output.numpy(),\n",
    "    rtol=1e-7,\n",
    "    atol=1e-5,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ai",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
